{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "03_Sentence_Classification_with_BERT.ipynb",
      "provenance": [],
      "machine_shape": "hm",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jalammar/jalammar.github.io/blob/master/notebooks/nlp/03_Sentence_Classification_with_BERT.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "izA3-6kffbdT",
        "colab_type": "text"
      },
      "source": [
        "# 03 - Sentence Classification with BERT\n",
        "\n",
        "**Status: Work in progress. Check back later.**\n",
        "\n",
        "In this notebook, we will use a pre-trained deep learning model to process some text. We will then use the output of that model to classify the text. The text is a list of sentences from film reviews, and we will classify each sentence as speaking either \"positively\" or \"negatively\" about its subject.\n",
        "\n",
        "## Models\n",
        "The classification model we will use is logistic regression from the scikit-learn library. The deep learning model is DistilBERT, a smaller and faster distilled version of BERT. We will use the implementation from the [Hugging Face transformers library](https://huggingface.co/).\n",
        "\n",
        "## Dataset\n",
        "The dataset we will use in this example is [SST2](https://nlp.stanford.edu/sentiment/index.html), which contains sentences from movie reviews, each labeled as either positive (has the value 1) or negative (has the value 0):\n",
        "\n",
        "\n",
        "<table class=\"features-table\">\n",
        "  <tr>\n",
        "    <th class=\"mdc-text-light-green-600\">\n",
        "    sentence\n",
        "    </th>\n",
        "    <th class=\"mdc-text-purple-600\">\n",
        "    label\n",
        "    </th>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td class=\"mdc-bg-light-green-50\" style=\"text-align:left\">\n",
        "      a stirring , funny and finally transporting re imagining of beauty and the beast and 1930s horror films\n",
        "    </td>\n",
        "    <td class=\"mdc-bg-purple-50\">\n",
        "      1\n",
        "    </td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td class=\"mdc-bg-light-green-50\" style=\"text-align:left\">\n",
        "      apparently reassembled from the cutting room floor of any given daytime soap\n",
        "    </td>\n",
        "    <td class=\"mdc-bg-purple-50\">\n",
        "      0\n",
        "    </td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td class=\"mdc-bg-light-green-50\" style=\"text-align:left\">\n",
        "      they presume their audience won't sit still for a sociology lesson\n",
        "    </td>\n",
        "    <td class=\"mdc-bg-purple-50\">\n",
        "      0\n",
        "    </td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td class=\"mdc-bg-light-green-50\" style=\"text-align:left\">\n",
        "      this is a visually stunning rumination on love , memory , history and the war between art and commerce\n",
        "    </td>\n",
        "    <td class=\"mdc-bg-purple-50\">\n",
        "      1\n",
        "    </td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td class=\"mdc-bg-light-green-50\" style=\"text-align:left\">\n",
        "      jonathan parker 's bartleby should have been the be all end all of the modern office anomie films\n",
        "    </td>\n",
        "    <td class=\"mdc-bg-purple-50\">\n",
        "      1\n",
        "    </td>\n",
        "  </tr>\n",
        "</table>\n",
        "\n",
        "## Installing the transformers library\n",
        "Let's start by installing the huggingface transformers library so we can load our deep learning NLP model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "To9ENLU90WGl",
        "colab_type": "code",
        "outputId": "4b46c997-c16c-4141-eaf2-e7aa7da6d3a0",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 632
        }
      },
      "source": [
        "!pip install transformers"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting transformers\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/fd/f9/51824e40f0a23a49eab4fcaa45c1c797cbf9761adedd0b558dab7c958b34/transformers-2.1.1-py3-none-any.whl (311kB)\n",
            "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.17.4)\n",
            "Collecting sentencepiece\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/14/3d/efb655a670b98f62ec32d66954e1109f403db4d937c50d779a75b9763a29/sentencepiece-0.1.83-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
            "\u001b[K     |████████████████████████████████| 1.0MB 53.8MB/s \n",
            "\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.10.14)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from transformers) (4.28.1)\n",
            "Collecting regex\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/e3/8e/cbf2295643d7265e7883326fb4654e643bfc93b3a8a8274d8010a39d8804/regex-2019.11.1-cp36-cp36m-manylinux1_x86_64.whl (643kB)\n",
            "\u001b[K     |████████████████████████████████| 645kB 39.9MB/s \n",
            "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)\n",
            "Collecting sacremoses\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/1f/8e/ed5364a06a9ba720fddd9820155cc57300d28f5f43a6fd7b7e817177e642/sacremoses-0.0.35.tar.gz (859kB)\n",
            "\u001b[K     |████████████████████████████████| 860kB 48.8MB/s \n",
            "\u001b[?25hRequirement already satisfied: botocore<1.14.0,>=1.13.14 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.13.14)\n",
            "Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.2.1)\n",
            "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.4)\n",
            "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2019.9.11)\n",
            "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
            "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.0)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.0)\n",
            "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.14->boto3->transformers) (0.15.2)\n",
            "Requirement already satisfied: python-dateutil<2.8.1,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.14->boto3->transformers) (2.6.1)\n",
            "Building wheels for collected packages: sacremoses\n",
            "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for sacremoses: filename=sacremoses-0.0.35-cp36-none-any.whl size=883999 sha256=cb6e175a99d8f14f69f593b734ea73303c23a7dbdc694be65435cfeebd1f3124\n",
            "  Stored in directory: /root/.cache/pip/wheels/63/2a/db/63e2909042c634ef551d0d9ac825b2b0b32dede4a6d87ddc94\n",
            "Successfully built sacremoses\n",
            "Installing collected packages: sentencepiece, regex, sacremoses, transformers\n",
            "Successfully installed regex-2019.11.1 sacremoses-0.0.35 sentencepiece-0.1.83 transformers-2.1.1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fvFvBLJV0Dkv",
        "colab_type": "code",
        "outputId": "140119e5-4cee-4604-c0d2-be279c18b125",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 63
        }
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.model_selection import GridSearchCV\n",
        "from sklearn.model_selection import cross_val_score\n",
        "import torch\n",
        "import transformers as ppb\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zQ-42fh0hjsF",
        "colab_type": "text"
      },
      "source": [
        "## Importing the dataset\n",
        "We'll use pandas to read the dataset and load it into a dataframe."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cyoj29J24hPX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "df = pd.read_csv('https://github.com/clairett/pytorch-sentiment-classification/raw/master/data/SST2/train.tsv', delimiter='\\t', header=None)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dMVE3waNhuNj",
        "colab_type": "text"
      },
      "source": [
        "For performance reasons, we'll use only the first 2,000 sentences from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gTM3hOHW4hUY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "batch_1 = df[:2000]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PRc2L89hh1Tf",
        "colab_type": "text"
      },
      "source": [
        "We can ask pandas how many sentences are labeled \"positive\" (value 1) and how many are labeled \"negative\" (value 0):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jGvcfcCP5xpZ",
        "colab_type": "code",
        "outputId": "4c4a8afc-1035-4b21-ba9a-c4bb6cfc6347",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "batch_1[1].value_counts()"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "1    1041\n",
              "0     959\n",
              "Name: 1, dtype: int64"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7_MO08_KiAOb",
        "colab_type": "text"
      },
      "source": [
        "## Loading the Pre-trained BERT model\n",
        "Let's now load a pre-trained DistilBERT model and its tokenizer."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q1InADgf5xm2",
        "colab_type": "code",
        "outputId": "dbc52856-4d52-42f8-8a74-a89944280a02",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "model_class, tokenizer_class, pretrained_weights = (ppb.DistilBertModel, ppb.DistilBertTokenizer, 'distilbert-base-uncased')\n",
        "\n",
        "\n",
        "# Load pretrained model/tokenizer\n",
        "tokenizer = tokenizer_class.from_pretrained(pretrained_weights)\n",
        "model = model_class.from_pretrained(pretrained_weights)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 231508/231508 [00:00<00:00, 2649246.11B/s]\n",
            "100%|██████████| 492/492 [00:00<00:00, 284634.15B/s]\n",
            "100%|██████████| 267967963/267967963 [00:03<00:00, 72728701.55B/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lZDBMn3wiSX6",
        "colab_type": "text"
      },
      "source": [
        "Right now, the variable `model` holds a pretrained DistilBERT model -- a version of BERT that is smaller, much faster, and requires a lot less memory.\n",
        "\n",
        "## Model #1: Preparing the Dataset\n",
        "Before we can hand our sentences to BERT, we need to do some minimal processing to put them in the format it requires.\n",
        "\n",
        "### Tokenization\n",
        "Our first step is to tokenize the sentences -- break them up into words and subwords in the format BERT is comfortable with."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dg82ndBA5xlN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tokenized = batch_1[0].apply((lambda x: tokenizer.encode(x, add_special_tokens=True)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mHwjUwYgi-uL",
        "colab_type": "text"
      },
      "source": [
        "### Padding\n",
        "After tokenization, `tokenized` is a pandas Series of sentences -- each sentence is represented as a list of token ids. We want BERT to process our examples all at once (as one batch). It's just faster that way. For that reason, we need to pad all lists to the same size, so we can represent the input as one 2-d array, rather than a list of lists (of different lengths)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "URn-DWJt5xhP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Length of the longest tokenized sentence\n",
        "max_len = max(len(i) for i in tokenized.values)\n",
        "\n",
        "# Right-pad every sentence with 0 (the [PAD] token id) up to max_len\n",
        "padded = np.array([i + [0]*(max_len-len(i)) for i in tokenized.values])"
      ],
      "execution_count": 0,
      "outputs": []
    },
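    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a minimal sketch of the padding step, the toy example below pads three made-up lists of token ids to a common length. The ids and the names `toy_tokenized` and `pad_id` are illustrative assumptions, not values taken from the dataset; 0 is used as the padding id to match the cell above.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical token id lists of different lengths (not real SST2 data).\n",
        "toy_tokenized = [[101, 7, 8, 102], [101, 9, 102], [101, 4, 5, 6, 102]]\n",
        "pad_id = 0  # assumed padding id, matching the notebook's use of 0\n",
        "\n",
        "# Find the longest list, then right-pad every list up to that length.\n",
        "toy_max_len = max(len(ids) for ids in toy_tokenized)\n",
        "toy_padded = np.array([ids + [pad_id] * (toy_max_len - len(ids)) for ids in toy_tokenized])\n",
        "\n",
        "print(toy_padded.shape)  # (3, 5)\n",
        "```\n",
        "\n",
        "Every row now has the same length, so the whole batch fits in a single 2-d array."
      ]
    },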
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mdjg306wjjmL",
        "colab_type": "text"
      },
      "source": [
        "Our dataset is now in the `padded` variable. We can view its dimensions below:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jdi7uXo95xeq",
        "colab_type": "code",
        "outputId": "be786022-e84f-4e28-8531-0143af2347bc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "np.array(padded).shape"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(2000, 59)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sDZBsYSDjzDV",
        "colab_type": "text"
      },
      "source": [
        "### Masking\n",
        "If we send `padded` directly to BERT, the padding would slightly confuse it. We need to create another variable that tells it to ignore (mask) the padding we've added when it processes its input. That's what `attention_mask` is:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4K_iGRNa_Ozc",
        "colab_type": "code",
        "outputId": "d03b0a9b-1f6e-4e32-831e-b04f5389e57c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "attention_mask = np.where(padded != 0, 1, 0)\n",
        "attention_mask.shape"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(2000, 59)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
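    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see what the mask looks like, here is a small sketch on a made-up padded batch. The array `toy_padded` is a hypothetical stand-in for `padded`, and the sketch assumes, as the cell above does, that id 0 is used only for padding.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical padded batch: 0 marks the padding positions.\n",
        "toy_padded = np.array([[101, 7, 8, 102, 0], [101, 9, 102, 0, 0]])\n",
        "\n",
        "# 1 where there is a real token, 0 where there is padding.\n",
        "toy_mask = np.where(toy_padded != 0, 1, 0)\n",
        "\n",
        "print(toy_mask.tolist())  # [[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]]\n",
        "```\n",
        "\n",
        "The mask has the same shape as the input, with one flag per token position."
      ]
    },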
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jK-CQB9-kN99",
        "colab_type": "text"
      },
      "source": [
        "## Model #1: And Now, Deep Learning!\n",
        "Now that we have our model and inputs ready, let's run our model!\n",
        "\n",
        "Calling `model()` runs our sentences through BERT. The output is stored in `last_hidden_states`; its first element holds the hidden states for every token of every sentence."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "39UVjAV56PJz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "input_ids = torch.tensor(padded)  \n",
        "attention_mask = torch.tensor(attention_mask)\n",
        "\n",
        "with torch.no_grad():\n",
        "    last_hidden_states = model(input_ids, attention_mask=attention_mask)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FoCep_WVuB3v",
        "colab_type": "text"
      },
      "source": [
        "Let's slice out only the part of the output that we need: the output corresponding to the first token of each sentence. BERT handles sentence classification by adding a special token called `[CLS]` (for classification) at the beginning of every sentence. The output corresponding to that token can be thought of as an embedding for the entire sentence.\n",
        "\n",
        "We'll save those in the `features` variable, as they'll serve as the features to our logistic regression model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C9t60At16PVs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "features = last_hidden_states[0][:,0,:].numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
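    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The slicing above can be illustrated on a small stand-in array. The array `toy_hidden` and its sizes are made up for the sketch and are much smaller than the real model output; only the indexing pattern is the same.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Stand-in for the model output: (sentences, tokens, hidden units).\n",
        "toy_hidden = np.arange(2 * 4 * 3).reshape(2, 4, 3)\n",
        "\n",
        "# Keep every sentence, only position 0 (the [CLS] slot), all hidden units.\n",
        "toy_features = toy_hidden[:, 0, :]\n",
        "\n",
        "print(toy_features.shape)  # (2, 3)\n",
        "```\n",
        "\n",
        "The token axis disappears, leaving one feature vector per sentence."
      ]
    },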
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VZVU66Gurr-",
        "colab_type": "text"
      },
      "source": [
        "The labels indicating which sentences are positive and which are negative now go into the `labels` variable."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JD3fX2yh6PTx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "labels = batch_1[1]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iaoEvM2evRx1",
        "colab_type": "text"
      },
      "source": [
        "## Model #2: Train/Test Split\n",
        "Let's now split our dataset into a training set and a testing set (even though we're using 2,000 sentences from the SST2 training set). By default, `train_test_split` holds out 25% of the rows for testing."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ddAqbkoU6PP9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_features, test_features, train_labels, test_labels = train_test_split(features, labels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
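    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The split above is random, so the accuracy will vary a little from run to run. As a hedged sketch on synthetic data, the example below shows one way to make the split repeatable and keep the class balance equal in both parts. The arrays and the values chosen for `test_size` and `random_state` are assumptions for illustration, not settings used by this notebook.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# Synthetic stand-ins for the real features and labels.\n",
        "rng = np.random.default_rng(0)\n",
        "toy_features = rng.normal(size=(100, 8))\n",
        "toy_labels = np.array([0, 1] * 50)\n",
        "\n",
        "# random_state makes the split repeatable; stratify keeps the class balance.\n",
        "X_train, X_test, y_train, y_test = train_test_split(\n",
        "    toy_features, toy_labels, test_size=0.25, random_state=42, stratify=toy_labels\n",
        ")\n",
        "\n",
        "print(X_train.shape, X_test.shape)  # (75, 8) (25, 8)\n",
        "```"
      ]
    },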
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B9bhSJpcv1Bl",
        "colab_type": "text"
      },
      "source": [
        "### [Bonus] Grid Search for Parameters\n",
        "We can dive into logistic regression directly with the scikit-learn default parameters, but sometimes it's worth searching for the best value of the C parameter, which determines regularization strength."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cyEwr7yYD3Ci",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# parameters = {'C': np.linspace(0.0001, 100, 20)}\n",
        "# grid_search = GridSearchCV(LogisticRegression(), parameters)\n",
        "# grid_search.fit(train_features, train_labels)\n",
        "\n",
        "# print('best parameters: ', grid_search.best_params_)\n",
        "# print('best scores: ', grid_search.best_score_)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KCT9u8vAwnID",
        "colab_type": "text"
      },
      "source": [
        "We now train the LogisticRegression model. If you've chosen to do the grid search, you can plug the best value of C into the model declaration (e.g. `LogisticRegression(C=5.2)`)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gG-EVWx4CzBc",
        "colab_type": "code",
        "outputId": "9252ceff-a7d0-4359-fef9-2f72be89c7d6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        }
      },
      "source": [
        "lr_clf = LogisticRegression()\n",
        "lr_clf.fit(train_features, train_labels)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
              "                   intercept_scaling=1, l1_ratio=None, max_iter=100,\n",
              "                   multi_class='warn', n_jobs=None, penalty='l2',\n",
              "                   random_state=None, solver='warn', tol=0.0001, verbose=0,\n",
              "                   warm_start=False)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3rUMKuVgwzkY",
        "colab_type": "text"
      },
      "source": [
        "## Evaluating Model #2\n",
        "So how well does our model do in classifying sentences? One way is to check the accuracy against the testing dataset:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iCoyxRJ7ECTA",
        "colab_type": "code",
        "outputId": "cfd86dea-5d16-476c-ab9b-47cbee3a014f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "lr_clf.score(test_features, test_labels)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.824"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "75oyhr3VxHoE",
        "colab_type": "text"
      },
      "source": [
        "How good is this score? What can we compare it against? Let's first look at a dummy classifier:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lnwgmqNG7i5l",
        "colab_type": "code",
        "outputId": "0042aed2-4fa8-4fa0-bf25-fdef70a10aac",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "from sklearn.dummy import DummyClassifier\n",
        "clf = DummyClassifier()\n",
        "\n",
        "scores = cross_val_score(clf, train_features, train_labels)\n",
        "print(\"Dummy classifier score: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Dummy classifier score: 0.527 (+/- 0.05)\n"
          ],
          "name": "stdout"
        }
      ]
    },
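    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough sanity check on that baseline, we can work out by hand what a classifier that always predicts the most common class would score. The sketch below uses the class counts printed earlier for the full 2,000-sentence sample, so it will not match the cross-validated score on the training split exactly.\n",
        "\n",
        "```python\n",
        "# Class counts taken from the value_counts() output earlier in the notebook.\n",
        "n_positive, n_negative = 1041, 959\n",
        "\n",
        "# A classifier that always predicts the most common class is right\n",
        "# exactly as often as that class appears.\n",
        "majority_accuracy = max(n_positive, n_negative) / (n_positive + n_negative)\n",
        "\n",
        "print(round(majority_accuracy, 4))  # 0.5205\n",
        "```\n",
        "\n",
        "Because the two classes are nearly balanced, any trivial baseline lands close to 50%."
      ]
    },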
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Lg4LOpoxSOR",
        "colab_type": "text"
      },
      "source": [
        "So our model clearly does better than a dummy classifier. But how does it compare against the best models?\n",
        "\n",
        "## Proper SST2 scores\n",
        "For reference, the [highest accuracy score](http://nlpprogress.com/english/sentiment_analysis.html) for this dataset is currently 96.8. DistilBERT can be trained to improve its score on this task. This process, called **fine-tuning**, updates the model’s weights so that it performs better on this sentence classification task (which we can call the downstream task). The fine-tuned DistilBERT achieves an accuracy score of 90.7. The full-size BERT model achieves 94.9.\n",
        "\n",
        "## Fine-tuning\n",
        "Want the model to obtain a better score? The transformers package provides a script (`/examples/run_glue.py`) to do just that. <!--You just need to organize the data in the format it expects (sentences and labels), and run the script.--> Check the library's docs for the exact way to run the fine-tuning, but at least for the current version (as of this writing), it's done in the following manner:\n",
        "\n",
        "First, download the GLUE dataset (Hugging Face provides [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) to do that).\n",
        "\n",
        "Then, install the additional packages required by the script:\n",
        "```pip install -r ./examples/requirements.txt```\n",
        "\n",
        "Then run the script and supply it with the proper parameters. For sentence classification, we'll pass SST-2 as the task name. <!--You can run it directly if you want to finetune against the SST2 dataset. If you want to run it on your own classification dataset, replace the SST2 data with yours but keeping the same format and file names-->\n",
        "```\n",
        "export GLUE_DIR=/path/to/glue\n",
        "export TASK_NAME=SST-2\n",
        "\n",
        "python ./examples/run_glue.py \\\n",
        "    --model_type bert \\\n",
        "    --model_name_or_path bert-base-uncased \\\n",
        "    --task_name $TASK_NAME \\\n",
        "    --do_train \\\n",
        "    --do_eval \\\n",
        "    --do_lower_case \\\n",
        "    --data_dir $GLUE_DIR/$TASK_NAME \\\n",
        "    --max_seq_length 128 \\\n",
        "    --per_gpu_eval_batch_size=8   \\\n",
        "    --per_gpu_train_batch_size=8   \\\n",
        "    --learning_rate 2e-5 \\\n",
        "    --num_train_epochs 3.0 \\\n",
        "    --output_dir /tmp/$TASK_NAME/\n",
        "```"
      ]
    }
  ]
}